classdef ProtoDef < handle
    %PROTODEF Helper class used to access the protocol data structure.
    %   A protocol holds trial definitions (TrialsDefinition), the trial
    %   sequence (TrialSequence, a cell array of indices into
    %   TrialsDefinition), a list of sound paths (Sounds) and the settings
    %   maxProtocolDuration and finishLastTrial.
    %
    %   p = ProtoDef()             creates an empty protocol.
    %   p = ProtoDef(protocolFile) loads a protocol from a MAT file name or
    %                              from a struct with the fields returned by
    %                              getProtocol (see loadProtocol).
    
    properties(SetAccess = private , GetAccess = public  )
        % Protocol format version; overwritten by loadProtocol.
        protocolVersion = 4.2;
        % Cell array of trial structs with fields 'name' and 'periods'
        % (see createEmptyTrial / createEmptyPeriod).
        TrialsDefinition = {}
        % Cell array of indices into TrialsDefinition (see addTrialToProtocol).
        TrialSequence = {}
        % Cell array of sound paths registered through addSound.
        Sounds = {};
        maxProtocolDuration = 0.0;
        finishLastTrial = 0;
    end
    
    properties(SetAccess = private , GetAccess = private  )
        % NOTE(review): not referenced by any method of this class.
        SoundList = {}
    end
    
    methods(Static)
        function trial = createEmptyTrial(trialName)
            %CREATEEMPTYTRIAL Create a trial struct with a single 'ITI' period.
            %   trial = ProtoDef.createEmptyTrial(trialName) returns a struct
            %   with fields 'name' (= trialName) and 'periods' (a cell array
            %   holding one period created by createEmptyPeriod('ITI')).
            %   Returns [] (and issues a warning) if creation fails.
            trial = [];
            try
                import ProtoDef.ProtoDef
                newTrial = struct();
                newTrial.name = trialName;
                % Every trial starts with the ITI period; addNewPeriod keeps
                % it as the last entry of the list.
                newTrial.periods = {ProtoDef.createEmptyPeriod('ITI')};
                trial = newTrial;
            catch err
                warning('Creating new Trial failed! %s\n', err.message)
            end
        end
        
        function name = getTrialName()
            %GETTRIALNAME Ask the user for the name of a new trial.
            %   Opens an input dialog and returns the entered text. Returns ''
            %   when the dialog is cancelled or fails.
            name = '';
            try
                prompt = {'Name of the new Trial:'};
                dlgTitle = 'Trial Name';   % avoid shadowing the built-in title()
                numlines = 1;
                defaultanswer = {''};
                answer = inputdlg(prompt, dlgTitle, numlines, defaultanswer);
                % inputdlg returns an empty cell array when the user cancels.
                if ~isempty(answer)
                    name = answer{1};
                end
            catch err
                warning('Getting name failed! %s\n', err.message)
            end
        end
        
        function period = createEmptyPeriod(periodName)
            %CREATEEMPTYPERIOD Create a period struct with default settings.
            %   period = ProtoDef.createEmptyPeriod(periodName) returns a struct
            %   whose 'name' is periodName. All other fields are initialised
            %   to the defaults below.
            try
                period = struct();
                period.name = periodName;
                period.minDuration = 0;
                period.maxDuration = 0;
                period.houselight = 0;
                period.bluecuelight = 0;
                period.yellowcuelight = 0;
                period.shock = 0;
                period.valve = 0;
                period.valveOpenDuration = 0.5;
                period.sound = -1;
                period.nextPeriod = '';
                period.condJumpAction = -1;
                period.condJumpPeriod = '';
            catch err
                warning('Creating new Period failed! %s\n', err.message)
            end
        end
        
        function newStruct = removeStructElement(oldStruct, idx)
            %REMOVESTRUCTELEMENT Return oldStruct without the element at idx.
            %   The result is always a 1-by-(N-1) row, where N = numel(oldStruct).
            %   idx must be a valid index into oldStruct.
            n = numel(oldStruct);
            newStruct = reshape(oldStruct([1:idx-1, idx+1:n]), 1, []);
        end
    end
    
    methods
        function this = ProtoDef(protocolFile)
            %PROTODEF Construct an empty protocol, or load one if an argument is given.
            try
                if nargin ~= 0
                    this.loadProtocol(protocolFile);
                end
            catch err
                error('Creating object ProtoDef failed! %s', err.message)
            end
        end
        
        function loadProtocol(this, protocol)
            %LOADPROTOCOL Replace this protocol's data with a loaded protocol.
            %   protocol is either a MAT file name (loaded with load()) or a
            %   struct with the fields protocolVersion, TrialsDefinition,
            %   TrialSequence, Sounds, maxProtocolDuration and finishLastTrial.
            %   If any field is missing, nothing is changed and a warning is issued.
            try
                if ischar(protocol)
                    protocol = load(protocol);
                end
                requiredFields = {'protocolVersion', 'TrialsDefinition', ...
                    'TrialSequence', 'Sounds', 'maxProtocolDuration', ...
                    'finishLastTrial'};
                % Validate first so a bad file never leaves a half-loaded object.
                missing = requiredFields(~isfield(protocol, requiredFields));
                if ~isempty(missing)
                    warning('Loading protocol failed! Missing field(s): %s', ...
                        sprintf('%s ', missing{:}))
                    return
                end
                this.protocolVersion = protocol.protocolVersion;
                this.TrialsDefinition = protocol.TrialsDefinition;
                this.TrialSequence = protocol.TrialSequence;
                this.Sounds = protocol.Sounds;
                this.maxProtocolDuration = protocol.maxProtocolDuration;
                this.finishLastTrial = protocol.finishLastTrial;
            catch err
                warning('Loading protocol failed! %s', err.message)
            end
        end
        
        function setMaxProtocolDuration(this, duration)
            %SETMAXPROTOCOLDURATION Set the maxProtocolDuration property.
            this.maxProtocolDuration = duration;
        end
        
        function setFinishLastTrial(this, value)
            %SETFINISHLASTTRIAL Set the finishLastTrial property.
            this.finishLastTrial = value;
        end
        
        function idx = getTrialIdx(this, trial)
            %GETTRIALIDX Resolve a trial name or index to an index into TrialsDefinition.
            %   A numeric trial is returned unchanged if it is a valid index.
            %   A name is looked up and the first trial with that name is
            %   returned. Returns -1 if the trial cannot be resolved.
            idx = -1;
            try
                nTrials = numel(this.TrialsDefinition);
                if isnumeric(trial)
                    if isscalar(trial) && trial == fix(trial) && trial > 0 && trial <= nTrials
                        idx = trial;
                    else
                        warning('Specified invalid trial index! %d', trial)
                    end
                else
                    for i = 1:nTrials
                        if strcmp(this.TrialsDefinition{i}.name, trial)
                            idx = i;
                            return
                        end
                    end
                end
            catch err
                warning('Cant find trial %s. Error msg %s', this.labelOf(trial), err.message)
            end
        end
        
        function trial = getTrial(this, trialName)
            %GETTRIAL Return the trial struct for a trial name or index, or [] if not found.
            trial = [];
            try
                trialIdx = this.getTrialIdx(trialName);
                if trialIdx < 0
                    fprintf('Cant find trial %s!\n', this.labelOf(trialName))
                    return
                end
                trial = this.TrialsDefinition{trialIdx};
            catch err
                warning('Getting the trial failed! %s', err.message)
            end
        end
        
        function trialList = getTrialList(this)
            %GETTRIALLIST Return the names of all defined trials as a column cell array.
            trialList = cell(0, 1);
            try
                nTrials = numel(this.TrialsDefinition);
                trialList = cell(nTrials, 1);
                for i = 1:nTrials
                    trialList{i} = this.TrialsDefinition{i}.name;
                end
            catch err
                warning('Getting the trial list failed! %s', err.message)
            end
        end
        
        function protocolSequence = getProtocolSequence(this)
            %GETPROTOCOLSEQUENCE Return the trial names of TrialSequence, in order.
            %   The result is a column cell array with one name per sequence entry.
            protocolSequence = cell(0, 1);
            try
                nEntries = numel(this.TrialSequence);
                protocolSequence = cell(nEntries, 1);
                for i = 1:nEntries
                    trialIdx = this.TrialSequence{i};
                    protocolSequence{i} = this.TrialsDefinition{trialIdx}.name;
                end
            catch err
                warning('Getting the protocol sequence failed! %s', err.message)
            end
        end
        
        function idx = getPeriodIdx(this, trialName, periodName)
            %GETPERIODIDX Return the index of the named period within a trial.
            %   trialName may be a name or an index. Returns the first
            %   matching period index, or -1 if the trial or period is not found.
            idx = -1;
            try
                if this.getTrialIdx(trialName) < 0
                    warning('Invalid trial name %s', this.labelOf(trialName))
                    return
                end
                periodList = this.getPeriodList(trialName);
                match = find(strcmp(periodName, periodList), 1);
                if ~isempty(match)
                    idx = match;
                end
            catch err
                warning('Cant find period idx %s\\%s. Error msg %s', ...
                    this.labelOf(trialName), this.labelOf(periodName), err.message)
            end
        end
        
        function result = isValidPeriodByName(this, trialName, periodName)
            %ISVALIDPERIODBYNAME True if the trial contains a period with this name.
            result = this.getPeriodIdx(trialName, periodName) >= 0;
        end
        
        function result = isValidTrial(this, trial)
            %ISVALIDTRIAL True if trial is a valid index or an existing trial name.
            %   Numeric input is checked silently (no warning for bad indices).
            result = false;
            try
                if isnumeric(trial)
                    result = isscalar(trial) && trial == fix(trial) && ...
                        trial > 0 && trial <= numel(this.TrialsDefinition);
                else
                    result = this.getTrialIdx(trial) > 0;
                end
            catch err
                warning('Cant validate trial %s. Error msg %s', this.labelOf(trial), err.message)
            end
        end
        
        function result = isValidPeriod(this, trialIdx, periodIdx)
            %ISVALIDPERIOD True if periodIdx is a valid period index of the trial.
            %   trialIdx may be a trial index or name. periodIdx must be numeric.
            result = false;
            try
                if ~this.isValidTrial(trialIdx) || ~isnumeric(periodIdx) || ~isscalar(periodIdx)
                    return
                end
                tidx = this.getTrialIdx(trialIdx);
                nPeriods = numel(this.TrialsDefinition{tidx}.periods);
                result = periodIdx == fix(periodIdx) && periodIdx > 0 && periodIdx <= nPeriods;
            catch err
                warning('Cant find period idx %s\\%s. Error msg %s', ...
                    this.labelOf(trialIdx), this.labelOf(periodIdx), err.message)
            end
        end
        
        function [period, trialIdx, periodIdx] = getPeriod(this, trialName, periodName)
            %GETPERIOD Look up a period by trial and period name.
            %   Returns the period struct together with its trial and period
            %   indices. period is 0 (and the indices may be -1) if either
            %   lookup fails.
            period = 0;
            trialIdx = -1;
            periodIdx = -1;
            try
                trialIdx = this.getTrialIdx(trialName);
                periodIdx = this.getPeriodIdx(trialName, periodName);
                if trialIdx >= 0 && periodIdx >= 0
                    period = this.TrialsDefinition{trialIdx}.periods{periodIdx};
                end
            catch err
                warning('Cant find period %s\\%s. Error msg %s', ...
                    this.labelOf(trialName), this.labelOf(periodName), err.message)
            end
        end
        
        function success = removePeriod(this, trialName, periodName)
            %REMOVEPERIOD Delete the named period from a trial. Returns true on success.
            success = false;
            trialIdx = this.getTrialIdx(trialName);
            periodIdx = this.getPeriodIdx(trialName, periodName);
            if this.isValidPeriod(trialIdx, periodIdx)
                oldPeriods = this.TrialsDefinition{trialIdx}.periods;
                % Called through the instance so no package import is needed.
                this.TrialsDefinition{trialIdx}.periods = this.removeStructElement(oldPeriods, periodIdx);
                success = true;
            end
        end
        
        function success = removeTrial(this, trialName)
            %REMOVETRIAL Delete a trial definition and its entries in TrialSequence.
            %   Sequence entries that referenced later trials are shifted down
            %   by one so they still point at the same trial definitions.
            %   Returns true on success.
            success = false;
            if ~this.isValidTrial(trialName)
                return
            end
            trialIdx = this.getTrialIdx(trialName);
            this.TrialsDefinition = this.removeStructElement(this.TrialsDefinition, trialIdx);
            
            newSequence = {};
            for i = 1:numel(this.TrialSequence)
                seqIdx = this.TrialSequence{i};
                if seqIdx == trialIdx
                    continue                 % drop references to the removed trial
                elseif seqIdx > trialIdx
                    seqIdx = seqIdx - 1;     % definitions after it moved down by one
                end
                newSequence{end+1} = seqIdx; %#ok<AGROW>
            end
            this.TrialSequence = newSequence;
            success = true;
        end
        
        function success = removeSequence(this, idx)
            %REMOVESEQUENCE Delete entry idx from TrialSequence. Returns true on success.
            success = false;
            try
                if idx > 0 && idx <= numel(this.TrialSequence)
                    this.TrialSequence = this.removeStructElement(this.TrialSequence, idx);
                    success = true;
                end
            catch err
                warning('Removing Trial from Protocol failed! %s', err.message)
            end
        end
        
        function periodList = getPeriodList(this, trialName)
            %GETPERIODLIST Return the period names of a trial as a column cell array.
            %   Returns {} if the trial is not valid.
            periodList = {};
            try
                if this.isValidTrial(trialName)
                    trialIdx = this.getTrialIdx(trialName);
                    periods = this.TrialsDefinition{trialIdx}.periods;
                    nPeriods = numel(periods);
                    periodList = cell(nPeriods, 1);
                    for i = 1:nPeriods
                        periodList{i} = periods{i}.name;
                    end
                end
            catch err
                warning('Cant list periods of trial %s. Error msg %s', this.labelOf(trialName), err.message)
            end
        end
        
        function periodName = addNewPeriod(this, trial)
            %ADDNEWPERIOD Add a default period with a unique generated name to a trial.
            %   The new period is named 'New Period Name<d>', using the
            %   smallest d not already in use. It is inserted before the last
            %   period (the ITI period), or appended if the trial has no
            %   periods. Returns the new period's name, or 'INVALID TRIAL!'
            %   if the trial is invalid or the period could not be added.
            periodName = 'INVALID TRIAL!';
            try
                if ~this.isValidTrial(trial)
                    return
                end
                tidx = this.getTrialIdx(trial);
                
                baseName = 'New Period Name';
                d = 0;
                while this.isValidPeriodByName(tidx, [baseName num2str(d)])
                    d = d + 1;
                end
                newName = [baseName num2str(d)];
                newPeriod = this.createEmptyPeriod(newName);
                
                periods = this.TrialsDefinition{tidx}.periods;
                if isempty(periods)
                    periods{end+1} = newPeriod;
                else
                    % Duplicate the last entry, then overwrite the old slot,
                    % so that the ITI period stays the last entry in the list.
                    periods{end+1} = periods{end};
                    periods{end-1} = newPeriod;
                end
                this.TrialsDefinition{tidx}.periods = periods;
                periodName = newName;
            catch err
                warning('Creating new Period failed! %s\n', err.message)
            end
        end
        
        function addTrialToProtocol(this, trial)
            %ADDTRIALTOPROTOCOL Append a trial (name or index) to TrialSequence.
            %   Unknown trials are not appended; a warning is issued instead.
            try
                idx = this.getTrialIdx(trial);
                if idx < 0
                    warning('Adding trial to protocol failed! Unknown trial %s\n', this.labelOf(trial))
                    return
                end
                this.TrialSequence{end+1} = idx;
            catch err
                warning('Adding trial to protocol failed! %s\n', err.message)
            end
        end
        
        function addMultipleTrialsToProtocol(this, trial)
            %ADDMULTIPLETRIALSTOPROTOCOL Append a trial to TrialSequence several times.
            %   Asks the user for the number of repetitions. At least one copy
            %   is added, including for non-numeric input (str2double -> NaN).
            %   Nothing is added if the dialog is cancelled or the trial is unknown.
            try
                idx = this.getTrialIdx(trial);
                if idx < 0
                    warning('Adding trial to protocol failed! Unknown trial %s\n', this.labelOf(trial))
                    return
                end
                
                prompt = {'# Trial:'};
                dlgTitle = 'Add multiple Trials';   % avoid shadowing the built-in title()
                numlines = 1;
                defaultanswer = {'1'};
                answer = inputdlg(prompt, dlgTitle, numlines, defaultanswer);
                if isempty(answer)   % dialog cancelled
                    return
                end
                number = str2double(answer{1});
                
                % max() ignores NaN, so invalid input still adds one entry.
                for i = 1:max(1, number)
                    this.TrialSequence{end+1} = idx;
                end
            catch err
                warning('Adding multiple trials to protocol failed! %s\n', err.message)
            end
        end
        
        function trialName = addNewTrial(this)
            %ADDNEWTRIAL Ask for a name and append a new empty trial definition.
            %   Returns the name of the new trial. Returns '' if the dialog is
            %   cancelled, the name already exists, or creation fails.
            trialName = '';
            try
                name = this.getTrialName();
                if isempty(name)
                    return
                end
                % getTrialIdx resolves names to the first match, so a duplicate
                % name would make the new trial unreachable by name.
                if this.getTrialIdx(name) > 0
                    warning('Trial %s already exists!', name)
                    return
                end
                trial = this.createEmptyTrial(name);
                if isempty(trial)
                    return
                end
                this.TrialsDefinition{end+1} = trial;
                trialName = name;
            catch err
                warning('Creating new Trial failed! %s\n', err.message)
            end
        end
        
        function newProtocol = updatePeriod(this, trialIdx, periodIdx, updatedPeriod)
            %UPDATEPERIOD Replace a period struct if the indices are valid.
            %   Invalid indices are ignored. Returns this object (a handle).
            newProtocol = this;
            try
                if this.isValidPeriod(trialIdx, periodIdx)
                    tidx = this.getTrialIdx(trialIdx);
                    this.TrialsDefinition{tidx}.periods{periodIdx} = updatedPeriod;
                end
            catch err
                warning('Updating period failed! %s\n', err.message)
            end
        end
        
        function sound = getSound(this, soundIdx)
            %GETSOUND Return entry soundIdx of Sounds, or [] if it cannot be read.
            sound = [];
            try
                sound = this.Sounds{soundIdx};
            catch err
                warning('Getting sound for idx %d failed! %s\n', soundIdx, err.message)
            end
        end
        
        function soundIdx = addSound(this, soundPath)
            %ADDSOUND Return the index of soundPath in Sounds, adding it if absent.
            %   Returns -1 on failure.
            soundIdx = -1;
            try
                existing = find(strcmp(soundPath, this.Sounds), 1);
                if ~isempty(existing)
                    soundIdx = existing;
                    return
                end
                this.Sounds{end+1} = soundPath;
                soundIdx = numel(this.Sounds);
            catch err
                warning('Getting index of sound %s failed! %s\n', soundPath, err.message)
            end
        end
        
        function shuffleTrialOrder(this)
            %SHUFFLETRIALORDER Randomly permute TrialSequence (result is a row cell array).
            nEntries = numel(this.TrialSequence);
            this.TrialSequence = reshape(this.TrialSequence(randperm(nEntries)), 1, []);
        end
        
        function protocol = getProtocol(this)
            %GETPROTOCOL Return the protocol data as a struct (the format accepted by loadProtocol).
            protocol = struct();
            protocol.protocolVersion = this.protocolVersion;
            protocol.TrialsDefinition = this.TrialsDefinition;
            protocol.TrialSequence = this.TrialSequence;
            protocol.Sounds = this.Sounds;
            protocol.maxProtocolDuration = this.maxProtocolDuration;
            protocol.finishLastTrial = this.finishLastTrial;
        end
    end
    
    methods(Static, Access = private)
        function label = labelOf(value)
            %LABELOF Convert a trial/period identifier to text for messages.
            if ischar(value)
                label = value;
            elseif isnumeric(value) || islogical(value)
                label = mat2str(value);
            else
                label = ['<' class(value) '>'];
            end
        end
    end
    
end

